import joblib
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder

class TransactionClassifier:
    """Classify transactions into spending categories with a random forest.

    The model is trained from a CSV file with ``amount``, ``time``,
    ``merchant`` and ``category`` columns. It is persisted with joblib to
    ``model_path`` so that a fresh instance can predict without retraining.
    ``self.model`` is an encoder + random-forest pipeline, so it exposes the
    usual ``fit``/``predict`` estimator API.
    """

    # Numeric feature columns, passed to the forest unchanged.
    _NUMERIC_COLUMNS = ("amount",)
    # String-valued feature columns that must be encoded before the forest.
    _CATEGORICAL_COLUMNS = ("time", "merchant")
    # All feature columns, in the order used for training and prediction.
    _FEATURE_COLUMNS = _NUMERIC_COLUMNS + _CATEGORICAL_COLUMNS

    def __init__(self, model_path: str = "model.pkl"):
        """Create an untrained classifier.

        Args:
            model_path: File that ``train`` saves the fitted model to. It is
                also the file ``predict`` loads from when this instance has
                not been trained itself. Defaults to ``"model.pkl"``, the
                path that was previously hard-coded.
        """
        self.model_path = model_path
        self.model = self._build_pipeline()
        # True once self.model holds a fitted estimator (trained or loaded).
        self._is_fitted = False
        # Integer category code -> display name
        # (dining, transport, entertainment, shopping).
        self.labels = {
            0: "餐饮",
            1: "交通",
            2: "娱乐",
            3: "购物",
        }

    @classmethod
    def _build_pipeline(cls):
        """Return an unfitted sklearn Pipeline: column encoder + random forest."""
        # RandomForestClassifier only accepts numeric input, so the string
        # columns are one-hot encoded. Categories not seen during training
        # (e.g. a new merchant) become all-zero columns instead of raising.
        # NOTE(review): 'time' is treated as an opaque category because its
        # format is not known here. Deriving numeric features such as hour of
        # day would likely generalize better; confirm the format first.
        preprocess = ColumnTransformer(
            transformers=[
                ("numeric", "passthrough", list(cls._NUMERIC_COLUMNS)),
                (
                    "categorical",
                    OneHotEncoder(handle_unknown="ignore"),
                    list(cls._CATEGORICAL_COLUMNS),
                ),
            ]
        )
        return Pipeline(
            [("preprocess", preprocess), ("classifier", RandomForestClassifier())]
        )

    def train(self, data_path: str):
        """Fit the model on historical transactions and persist it.

        Args:
            data_path: CSV file with ``amount``, ``time``, ``merchant`` and
                ``category`` columns. ``category`` presumably holds the
                integer codes used as keys of ``self.labels``; TODO confirm.

        Raises:
            KeyError: If a required column is missing from the CSV.
        """
        df = pd.read_csv(data_path)
        features = df[list(self._FEATURE_COLUMNS)]
        target = df["category"]
        self.model.fit(features, target)
        self._is_fitted = True
        joblib.dump(self.model, self.model_path)

    def predict(self, amount: float, time: str, merchant: str) -> str:
        """Predict the category name of a single transaction.

        If this instance has not been trained, the model saved at
        ``model_path`` is loaded once and reused by later calls.

        Returns:
            The display name from ``self.labels`` for the predicted code.
            If the model was trained on names rather than codes, the
            predicted name itself.

        Raises:
            FileNotFoundError: If the instance is untrained and no model
                file exists at ``model_path``.
            KeyError: If the model predicts a code missing from ``self.labels``.
        """
        if not self._is_fitted:
            # joblib.load unpickles the file, which can execute arbitrary
            # code. Only point model_path at files this application wrote.
            self.model = joblib.load(self.model_path)
            self._is_fitted = True
        prediction = self.model.predict(self._to_frame(amount, time, merchant))
        return self._decode_label(prediction[0])

    @classmethod
    def _to_frame(cls, amount, time, merchant) -> pd.DataFrame:
        """Wrap one transaction in a single-row DataFrame with the feature names."""
        # The preprocessing step selects columns by name, so the model must
        # receive the same column names it was trained with, not a bare list.
        return pd.DataFrame(
            [[amount, time, merchant]], columns=list(cls._FEATURE_COLUMNS)
        )

    def _decode_label(self, label) -> str:
        """Map one raw model prediction to a human-readable category name."""
        if isinstance(label, str):
            # The model was trained on category names; nothing to translate.
            return label
        # int() normalizes numpy integer scalars returned by sklearn.
        return self.labels[int(label)]